"""
Description
-----------
CASA 6 imaging script for exoALMA observations. This contains a general
function, `iterative_tclean`, which runs an iterative tclean using a mask based
on a shallow initial clean. Either a `robust` value (and optionally a
`uv_taper`) can be specified, or a desired circular beam size can be given with
`beam`, in which case `robust` and `uv_taper` are optimized to give a circular
beam, following the approach used for MAPS.

Usage
-----
First you must activate the exoALMA virtual environment with

> source /lustre/cv/projects/exoALMA/casa_modular_6.2.1/bin/activate

then you can import the functions into your workflow. Source-specific scripts
are available on the Dataverse repository.

Authors
-------
John Ilee (Leeds)
Richard Teague (MIT)
Gianni Cataldi (NAOJ)
Ryan Loomis (NRAO)
"""

import os
import shutil
import subprocess

import numpy as np
import scipy.constants as sc

import casatools
import casatasks
# Shared CASA image tool instance. The functions below open it on an image,
# use it, and close it again (`ia.close()` / `ia.done()`) before returning.
ia = casatools.image()

# Define the disks that have ACA data.
# NOTE(review): not referenced by the functions in this section; presumably
# consumed by the source-specific imaging scripts -- confirm against callers.

ACA_disks = ['AA_Tau',
             'DM_Tau',
             'HD_34282',
             'J1604-2130',
             'J1615-3255',
             'LkCa_15',
             'V4046_Sgr',
              ]

def iterative_tclean(ms_path, save_path, restfreq, spw, start, nchan, width,
        beam=None, robust=None, threshold=None, uv_taper='',
        weighting='briggs', gridder='standard', scales=None,
        imsize=1024, cell='0.025arcsec', restoringbeam='common',
        initial_clean_depth=7.0, mask_kernel='0.7arcsec',
        exportfits=None, moments=None, mask_moments=True,
        only_dirty=False):
    """
    A wrapper for CASA tclean to image data in an iterative manner. The data
    will first be cleaned to a very shallow level (default is 7xRMS). This CLEAN
    model will then be used to define a mask by convolving it with a Gaussian
    kernel. The data will then be cleaned down to a range of different
    thresholds, starting from the previous clean model to speed up the process.
    This is a very simplified version of the auto-masking built into CASA.

    If a `beam` value is given in [arcsec], an optimization will take place to
    find the robust value and uv-taper which results in the most circular beam
    with this FWHM. A final `imsmooth` will be run to obtain the desired beam.
    This is following the same procedure as was used for MAPS. Note that if
    `beam` is not specified, a `robust` (and optionally a `uv_taper`) must be
    provided.

    The default `imsize=1024` and `cell='0.025arcsec'` yield a field of view of
    25.6" x 25.6", somewhat larger than the ~21" primary beam at 300 GHz.

    Image names are built as `<save_path>_beam<beam>` (or
    `<save_path>_robust<robust>`), then `_<width>` (or `_<width>chans` for a
    non-string width), then `_<N>sigma`, followed by `.dirty` for the initial
    dirty image or `.clean` for the cleaned images.

    Args:
        ms_path (str): Path to MS file to image.
        save_path (str): Path (prefix) to save the imaging products to.
        restfreq (str): Line rest frequency in standard CASA form.
        spw (str): Spectral windows to image in standard CASA form.
        start (str): Starting channel to image in standard CASA form.
        nchan (int): Number of channels to image.
        width (str): Spacing between channels in standard CASA form.
        beam (optional[float]): Desired circular beam FWHM in [arcsec]. Cannot
            be combined with `robust`.
        robust (optional[float]): Robust value to use for the imaging. Required
            if `beam` is not given.
        threshold (optional[float/list]): Threshold value to CLEAN down to in
            [sigma]. This can either be a single value or a 1D list of values
            which will result in multiple images cleaned to different depths.
            Each image will be started from the prior image model to save time.
            Defaults to [4, 3].
        uv_taper (optional[str/list]): uv tapering to apply in standard CASA
            form.
        weighting (optional[str]): Weight scheme to use.
        gridder (optional[str]): Gridding scheme to use.
        scales (optional[list]): List of scales for multi-scale cleaning.
            Defaults to [0, 5, 15, 35, 100].
        imsize (optional[int]): Number of pixels for the image.
        cell (optional[str]): Pixel size for the image in standard CASA form.
            Must be given in 'arcsec' when `beam` is used.
        restoringbeam (optional[str]): Restoring beam to use for the imaging.
        initial_clean_depth (optional[float]): Threshold value for first CLEAN
            to define mask in [sigma]. Don't make this too low!
        mask_kernel (optional[str]): Size of Gaussian to convolve with .model to
            create mask. Bigger is more conservative.
        exportfits (optional[list]): List of outputs to be saved as FITS.
            Defaults to ['.image', '.pbcor', '.mask', '.psf']. '.image' must be
            included for moment maps to be made.
        moments (optional[str/list]): A (list of) methods to collapse the image
            cube using `bettermoments`. Defaults to
            ['zeroth', 'first', 'second'].
        mask_moments (optional[bool]): Whether to use the CLEAN mask for the
            generation of moments with `bettermoments` (requires '.mask' in
            `exportfits`).
        only_dirty (optional[bool]): If `True`, only make the dirty image.

    Raises:
        ValueError: If both or neither of `beam` and `robust` are given, or if
            `threshold` is not a scalar or a 1D sequence.
    """

    # Fill in list-valued defaults here rather than in the signature so that a
    # single list object is never shared between calls.

    if threshold is None:
        threshold = [4, 3]
    if scales is None:
        scales = [0, 5, 15, 35, 100]
    if exportfits is None:
        exportfits = ['.image', '.pbcor', '.mask', '.psf']
    if moments is None:
        moments = ['zeroth', 'first', 'second']

    # Check to see if circularization is required and run if necessary.
    # After running this, both `robust` and `uv_taper` should be set.

    if beam is not None:
        if robust is not None:
            raise ValueError("Cannot request both `beam` and `robust`.")
        circularization = True
    elif robust is None:
        raise ValueError("Must specify either `beam` or `robust`.")
    else:
        circularization = False

    if circularization:
        print("Finding `robust` and `uv_taper`.")
        from circularize_beam import get_robust_and_taper
        params = get_robust_and_taper(base_ms=ms_path,
                                      beam=beam,
                                      npix=imsize,
                                      dpix=float(cell.replace('arcsec', '')))
        robust = params[0]
        uv_taper = '{:.2f}arcsec, {:.2f}arcsec, {:.1f}deg'
        uv_taper = uv_taper.format(*params[1:]).split(', ')

    if robust is None:
        raise ValueError("`robust` cannot be `None.`")
    if uv_taper is None:
        raise ValueError("`uv_taper` cannot be `None`.")

    # Build the image-name prefix. There are two forms depending on whether
    # the robust value or the beam size was provided. The threshold is
    # appended by `_imname`; concatenating (rather than calling `str.format`
    # on the whole path) keeps any braces in `save_path` from being
    # interpreted as format fields.

    prefix = save_path
    if circularization:
        prefix += '_beam{:.2f}'.format(beam)
    else:
        prefix += '_robust{:.1f}'.format(robust)
    if isinstance(width, str):
        prefix += '_{}'.format(width.replace('/', ''))
    else:
        prefix += '_{}chans'.format(width)

    def _imname(sigma):
        """Return the image-name root for a clean to `sigma` x RMS."""
        return prefix + '_{:.0f}sigma'.format(sigma)

    # Arguments shared by every tclean call. `uvtaper` is passed separately
    # because it may be dropped after the dirty image has been made.

    tclean_kwargs = dict(vis=ms_path,
                         specmode='cube',
                         restfreq=restfreq,
                         spw=spw,
                         start=start,
                         nchan=nchan,
                         width=width,
                         deconvolver='multiscale',
                         scales=scales,
                         weighting=weighting,
                         gridder=gridder,
                         robust=robust,
                         cell=cell,
                         imsize=imsize,
                         interactive=False,
                         perchanweightdensity=False,
                         restoringbeam=restoringbeam,
                         )

    # Initial dirty image to measure the RMS.

    print("Running initial dirty image...")
    image_string = _imname(0.0) + '.dirty'
    casatasks.tclean(imagename=image_string,
                     niter=0,
                     uvtaper=uv_taper,
                     **tclean_kwargs)

    # Quit here if only_dirty is specified.

    if only_dirty:
        return

    # If the resulting beam has a major beam size that's larger than the desired
    # beam, we can just remove the uv taper and make up the rest with the
    # imsmooth later. This isn't ideal, but works for the few cases where the
    # circularize_beam optimization fails.

    if circularization:
        bmaj = beam_statistics(image_string + '.image')[0]
        if bmaj > beam:
            print("WARNING: Dirty beam too big with uv-taper. Skipping taper.")
            uv_taper = ''

    # Calculate the RMS in an annulus in the outer regions of the image.

    estimated_RMS = image_statistics(image_string=image_string,
                                     mask_name=None,
                                     statistic='rms',
                                     mJy=True,
                                     )

    # Second call to tclean to clean down to a conservative threshold set
    # by `initial_clean_depth`. Defaults to 7xRMS.

    message = "Running a shallow clean down to {:.0f} sigma."
    print(message.format(initial_clean_depth))

    threshold_string = '{:.2f}mJy'.format(initial_clean_depth * estimated_RMS)
    image_string = _imname(initial_clean_depth) + '.clean'
    casatasks.tclean(imagename=image_string,
                     niter=100000,
                     uvtaper=uv_taper,
                     threshold=threshold_string,
                     **tclean_kwargs)

    # Convolve the model to generate the mask, then binarize it in place.

    mask_name = image_string + '.shallow_clean_mask'
    casatasks.imsmooth(imagename=image_string + '.model',
                       outfile=mask_name,
                       overwrite=True,
                       kernel='gauss',
                       major=mask_kernel,
                       minor=mask_kernel,
                       pa='0deg',
                       targetres=True,
                       )
    ia.open(mask_name)
    try:
        ia.calc("iif('{}' > 0.001, 1, 0)".format(mask_name))
    finally:
        ia.done()

    # Re-estimate the RMS of the data in the area outside the new mask.

    estimated_RMS = image_statistics(image_string=image_string,
                                     mask_name=mask_name,
                                     statistic='rms',
                                     mJy=True,
                                     )

    # Apply the third round of cleaning down to the requested depth. This will
    # cycle through a list of decreasing thresholds, each time starting from
    # the model from the previous image. Check the dimensionality *before*
    # `np.unique`, which always flattens its input.

    threshold = np.atleast_1d(threshold)
    if threshold.ndim != 1:
        raise ValueError("`threshold` must be 1D.")
    threshold = np.unique(threshold)[::-1]

    start_model = image_string + '.model'

    for threshold_tmp in threshold:

        message = "Running a deep clean down to {:.0f} sigma."
        print(message.format(threshold_tmp))

        threshold_string = '{:.2f}mJy'.format(threshold_tmp * estimated_RMS)
        image_string = _imname(threshold_tmp) + '.clean'

        casatasks.tclean(imagename=image_string,
                         niter=100000,
                         uvtaper=uv_taper,
                         threshold=threshold_string,
                         mask=mask_name,
                         startmodel=start_model,
                         **tclean_kwargs)

        # Apply final circularization. Smooth into a temporary image, keep the
        # unsmoothed image as `.unsmooth.image`, then move the smoothed image
        # into place. `beam` is in [arcsec], so pass it with explicit units.

        if circularization:
            beam_quantity = '{}arcsec'.format(beam)
            casatasks.imsmooth(imagename=image_string + '.image',
                               outfile=image_string + '.tmp.image',
                               overwrite=True,
                               kernel='gauss',
                               major=beam_quantity,
                               minor=beam_quantity,
                               pa='0deg',
                               targetres=True,
                               )
            _replace_image(image_string + '.image',
                           image_string + '.unsmooth.image')
            _replace_image(image_string + '.tmp.image',
                           image_string + '.image')

        start_model = image_string + '.model'

        # `write_history(image_string, mask_name, estimated_RMS, threshold_tmp,
        # mask_kernel)` is deliberately not called: it records incorrect values
        # if imsmooth was used above.

        # If JvM correction is to be added, do it here.

        # Apply the primary beam correction.

        casatasks.impbcor(imagename=image_string + '.image',
                          pbimage=image_string + '.pb',
                          outfile=image_string + '.pbcor',
                          overwrite=True,
                          )

        # Export the images as FITS.

        for ext in exportfits:
            casatasks.exportfits(imagename=image_string + ext,
                                 fitsimage=image_string + ext + '.fits',
                                 dropstokes=True,
                                 overwrite=True,
                                 history=True,
                                 )

        # Calculate the moment maps from the exported FITS cube.

        if nchan > 1:
            if '.image' not in exportfits:
                print("WARNING: '.image' not exported to FITS; "
                      "skipping moment maps.")
            else:
                mask_fits = None
                if mask_moments and '.mask' in exportfits:
                    mask_fits = image_string + '.mask.fits'
                _run_bettermoments(image_string + '.image.fits',
                                   moments,
                                   mask_fits)

    return


def _replace_image(src, dst):
    """
    Rename the image `src` to `dst`, removing any existing `dst` first.

    CASA images are typically directories on disk, so a plain `mv src dst`
    onto an already-existing `dst` (e.g. when a script is re-run) would move
    `src` *inside* `dst` instead of replacing it.
    """
    if os.path.isdir(dst) and not os.path.islink(dst):
        shutil.rmtree(dst)
    elif os.path.lexists(dst):
        os.remove(dst)
    shutil.move(src, dst)


def _run_bettermoments(image_fits, methods, mask_fits=None):
    """
    Run the `bettermoments` command line tool on `image_fits` once per method.

    Arguments are passed as a list (no shell), so paths containing spaces or
    shell metacharacters are safe. Failures are reported but not raised,
    keeping moment-map generation best-effort.

    Args:
        image_fits (str): Path to the FITS cube to collapse.
        methods (str/list): A (list of) `bettermoments` `-method` values.
        mask_fits (optional[str]): Path to a FITS mask passed via `-mask`.
    """
    for method in np.atleast_1d(methods):
        cmd = ['bettermoments', image_fits, '-method', str(method)]
        if mask_fits is not None:
            cmd += ['-mask', mask_fits]
        try:
            status = subprocess.run(cmd).returncode
        except OSError as err:
            print("WARNING: could not run bettermoments: {}".format(err))
            return
        if status != 0:
            print("WARNING: bettermoments exited with status {}.".format(status))


def image_statistics(image_string, mask_name=None, statistic='rms', mJy=True):
    """
    Returns some basic image statistics. If a mask is not provided through
    `mask_name`, an annular region centred on the image is used, with inner
    and outer radii of 90% and 95% of the image half-width. If that annulus
    contains no valid (non-NaN) pixels, both radii are reduced by 5% of the
    half-width at a time, down to 50% and 55%. Note that for no mask,
    `mask_name=''` must be given, i.e., be explicit in that no mask is
    required.

    Args:
        image_string (str): Path to the CLEAN image, with or without a trailing
            `.image` extension.
        mask_name (optional[str]): Mask to use. Pixels where the mask is below
            0.1 are used. Note that `mask_name=''` is different to
            `mask_name=None` in that the former applies no mask while the
            latter uses the annular region described above.
        statistic (optional[str]): Key of the `imstat` result to return.
            Defaults to the RMS.
        mJy (optional[bool]): Whether to return the statistic in units of [mJy].

    Returns:
        value (float): The requested statistic in units of [mJy] if `mJy=True`
            else in [Jy].

    Raises:
        ValueError: If no valid pixels were available to compute `statistic`.
    """
    # Strip only a trailing `.image` so `.image` elsewhere in the path (e.g. in
    # a directory name) is left intact.
    if image_string.endswith('.image'):
        image_string = image_string[:-len('.image')]
    image_name = image_string + '.image'

    if mask_name is None:
        ia.open(image_name)
        try:
            # Only the spatial size is needed, so query the shape rather than
            # reading the whole cube into memory.
            image_size = ia.shape()[0]
            pixel_size = np.abs(ia.summary(list=False)['incr'][0]) / sc.arcsec
        finally:
            ia.close()
        m0 = image_size / 2
        mask = 'annulus[[{}pix,{}pix], ["{}arcsec", "{}arcsec"]]'
        for mf in np.arange(0.5, 0.95, 0.05)[::-1]:
            mi = mf * m0 * pixel_size
            mo = (mf + 0.05) * m0 * pixel_size
            stats = casatasks.imstat(imagename=image_name,
                                     region=mask.format(m0, m0, mi, mo),
                                     )
            if stats[statistic].size > 0:
                break
            print("WARNINING: NaN in noise annulus; reducing annulus size.")
    elif mask_name != '':
        stats = casatasks.imstat(imagename=image_name,
                                 mask='"{}" < 0.1'.format(mask_name),
                                 )
    else:
        stats = casatasks.imstat(imagename=image_name)

    values = np.atleast_1d(stats[statistic])
    if values.size == 0:
        raise ValueError("No valid pixels to compute `{}` for {}.".format(
            statistic, image_name))
    return values[0] * 1e3 if mJy else values[0]


def beam_statistics(image_string):
    """
    Gather the beam statistics, including the epsilon factor as defined in Eqn.
    15 in MAPS II (Czekala et al., 2021). Some of the code here was based on
    code used in MPOL.

    The PSF is read from `<image_string>.psf`. For a cube, the first channel
    whose PSF sum is finite and non-zero is used. When the image has per-plane
    beams, the CLEAN beam of that same channel is reported.

    Args:
        image_string (str): Path to the CLEAN image, with or without a trailing
            `.image` extension.

    Returns:
        bmaj (float): Beam major axis in [arcsec].
        bmin (float): Beam minor axis in [arcsec].
        bphi (float): Beam position angle in [deg].
        beam_area (float): CLEAN beam area in [arcsec^2].
        psf_area (float): True PSF area in [arcsec^2].
        epsilon (float): Ratio of `beam_area` to `psf_area`.

    References:
        MPOL: https://github.com/MPoL-dev/MPoL.
        Czekala et al. (2021): https://ui.adsabs.harvard.edu/abs/2021ApJS..257....2C/abstract
    """

    # Strip only a trailing `.image` so `.image` elsewhere in the path is kept.

    if image_string.endswith('.image'):
        image_string = image_string[:-len('.image')]

    # Open the PSF file, grab the data and summary. Close the shared tool even
    # if reading fails.

    ia.open(image_string + '.psf')
    try:
        psf = np.squeeze(ia.getregion()).copy()
        psf_summary = ia.summary(list=False)
    finally:
        ia.close()

    assert np.all(psf_summary['axisunits'][:2] == 'rad')
    assert np.abs(psf_summary['incr'][0]) == np.abs(psf_summary['incr'][1])
    pixel_size = np.abs(psf_summary['incr'][0]) / sc.arcsec

    # If multiple PSFs are found, use the first channel whose PSF sum is finite
    # and non-zero (skipping empty or NaN planes).

    channel = 0
    if psf.ndim == 3:
        psf_sum = np.sum(psf, axis=(0, 1))
        valid = np.flatnonzero(np.isfinite(psf_sum) & (psf_sum != 0))
        assert valid.size > 0, 'PSF area is 0.'
        channel = int(valid[0])
        psf = psf[:, :, channel]
    assert np.sum(psf) != 0, 'PSF area is 0.'
    assert psf.ndim == 2, f'psf.ndim = {psf.ndim}'

    # Use the single restoring beam if there is one, otherwise the per-plane
    # beam of the channel selected above (falling back to the first plane).

    if 'restoringbeam' in psf_summary:
        beam_data = psf_summary['restoringbeam']
    else:
        plane_beams = psf_summary['perplanebeams']['beams']
        channel_key = '*{}'.format(channel)
        if channel_key not in plane_beams:
            channel_key = '*0'
        beam_data = plane_beams[channel_key]['*0']

    for ax in ('major', 'minor'):
        assert beam_data[ax]['unit'] == 'arcsec'
    assert beam_data['positionangle']['unit'] == 'deg'

    bmaj = beam_data['major']['value']
    bmin = beam_data['minor']['value']
    bphi = beam_data['positionangle']['value']

    # Check that the peak of the PSF is in the image center.

    center = np.unravel_index(np.argmax(psf), psf.shape)
    for i in range(2):
        expected_center = (psf.shape[i] - 1) / 2
        assert np.abs(center[i] - expected_center) < 2

    # Define the pixel offsets from the PSF peak, radius and polar angle.

    X, Y = np.indices(psf.shape)
    X = X - center[0]
    Y = Y - center[1]
    radius = np.hypot(X, Y)
    theta = np.arctan2(Y, X)

    # For a range of wedges, find the first null (closest negative pixel).
    # Pixels inside this are considered part of the beam, while pixels
    # outside are set to zero. Wedges overlap slightly (0.1 rad) so that all
    # pixels are considered. Wedges without any negative pixel are skipped.

    wedges = np.linspace(-np.pi, np.pi, 20)
    inside_beam = np.ones_like(psf, dtype=bool)
    for theta_min, theta_max in zip(wedges[:-1], wedges[1:]):
        wedge_mask = (theta >= theta_min - 0.1) & (theta < theta_max + 0.1)
        null_mask = wedge_mask & (psf < 0.0)
        if not np.any(null_mask):
            continue
        null = radius[null_mask].min()
        inside_beam &= ~(wedge_mask & (radius > null))
    psf = np.where(inside_beam, psf, 0.0)

    # Calculate the epsilon value following Eqn. 15 of Czekala et al. (2021).

    beam_area = 2.0 * np.pi * FWHM_to_sigma(bmaj) * FWHM_to_sigma(bmin)
    psf_area = np.sum(psf) * pixel_size**2
    epsilon = beam_area / psf_area

    return bmaj, bmin, bphi, beam_area, psf_area, epsilon


def FWHM_to_sigma(FWHM):
    """
    Convert a Gaussian full width at half maximum into its standard deviation.

    Works element-wise on NumPy arrays as well as on scalars; the output is in
    the same units as `FWHM`.
    """
    fwhm_per_sigma = 2.0 * np.sqrt(2.0 * np.log(2.0))
    return FWHM / fwhm_per_sigma


def write_history(image_string, mask_name, estimated_RMS, threshold,
        mask_kernel):
    """
    Include history in the image header.

    Appends a set of "exoALMA: ..." lines to the history of
    `<image_string>.image`. These summarize the CLEAN settings, the final RMS
    and peak SNR, and the beam/PSF properties from `beam_statistics`.

    Args:
        image_string (str): Path to the CLEAN image, with or without a trailing
            `.image` extension.
        mask_name (str): Name of the mask used for creating the image; passed
            to `image_statistics` to measure the final RMS.
        estimated_RMS (float): Estimated RMS in [mJy/beam] used for calculating
            the threshold.
        threshold (float): The number of sigma used for the threshold.
        mask_kernel (str): Kernel size used for making the mask, e.g.
            '0.7arcsec'.
    """

    # Strip only a trailing `.image`, matching the sibling helpers.

    if image_string.endswith('.image'):
        image_string = image_string[:-len('.image')]

    image_max = image_statistics(image_string=image_string,
                                 mask_name='',
                                 statistic='max',
                                 mJy=True,
                                 )

    image_RMS = image_statistics(image_string=image_string,
                                 mask_name=mask_name,
                                 statistic='rms',
                                 mJy=True,
                                 )

    bmaj, bmin, bphi, beam_area, psf_area, epsilon = beam_statistics(
        image_string=image_string)

    # Create strings containing custom logging information. `threshold` is in
    # [sigma]; the absolute CLEAN threshold is `threshold * estimated_RMS`.

    info = ['TCLEAN Properties',
            'RMS used for CLEAN = {:.2f} mJy/beam'.format(estimated_RMS),
            'CLEAN threshold = {:.1f} sigma ({:.2f} mJy)'.format(
                threshold, threshold * estimated_RMS),
            'Kernel size used to create mask = {}'.format(mask_kernel),
            'Final RMS = {:.2f} mJy/beam'.format(image_RMS),
            'Peak SNR = {:.2f}'.format(image_max / image_RMS),
            'PSF Properties',
            'CLEAN beam = {:.3f}" x {:.3f}" ({:.1f} deg)'.format(
                bmaj, bmin, bphi),
            'CLEAN beam area = {:.3f} arcsec^2'.format(beam_area),
            'PSF area (within first null) = {:.3f} arcsec^2'.format(psf_area),
            'Epsilon = {:.3f}'.format(epsilon),
            ]

    # Append this to the history.

    for info_line in info:
        casatasks.imhistory(imagename=image_string + '.image',
                            mode="append",
                            message="exoALMA: " + info_line)



def get_nchan_vstart(vlsr, vrng, width):
    """
    Choose an odd channel count and starting velocity centred on `vlsr`.

    The number of channels is roughly `2 * vrng / width` (channels on either
    side of `vlsr`), forced to be odd so that one channel is centred exactly
    on `vlsr`.

    Args:
        vlsr (float): Central (LSR) velocity in [km/s].
        vrng (float): Velocity range either side of `vlsr` in [km/s].
        width (float): Channel width in [m/s].

    Returns:
        vstart (float): Velocity of the first channel in [km/s].
        nchan (int): Number of channels (odd).
    """
    chan_kms = width * 1e-3

    nchan = int(np.ceil(2.0 * vrng / chan_kms))
    if not nchan % 2:
        # Even count: drop one channel to make it odd, then add two back if
        # the outermost channel centre no longer reaches `vrng`.
        nchan -= 1
        if (nchan // 2) * chan_kms < vrng:
            nchan += 2

    # Step back half the (odd) channel count so the middle channel is at vlsr.
    half = nchan // 2
    return vlsr - half * chan_kms, nchan


def create_velocity_grid(vstart, nchan, width):
    """
    Return the channel velocities that will be used for imaging.

    Mostly useful as a sanity check of `get_nchan_vstart`.

    Args:
        vstart (float): Velocity of the first channel in [km/s].
        nchan (int): Number of channels.
        width (float): Channel width in [m/s].

    Returns:
        v (ndarray): The `nchan` channel velocities in [km/s], starting at
            `vstart` and spaced by `width`.
    """
    channel_index = np.arange(nchan)
    offsets_m_s = channel_index * width
    return vstart + offsets_m_s * 1e-3